"""
来自论文[1] 许利恒.基于椭圆特征的陨石坑检测与识别方法研究[D].北京:北京航空航天大学,2022

按照论文的方法复现，关于从四点金字塔拓展至整幅陨石坑的算法，这里使用了姿态解算的结果作为重投影平差的依据，不再使用导航坑对的不变量计算，增加了对陨石坑的筛选算法
"""

import numpy as np
from ...base import Matching
from utils.ellipse import center_ellipse, radius_ellipse
from itertools import combinations
from scipy.optimize import linear_sum_assignment
from utils.pose import pose_calculate
import cv2
import pandas as pd


class TriadHashVote(Matching):
    def __init__(self, catalog_dir, **kwargs):
        super().__init__(catalog_dir, **kwargs)
        self.descriptor = pd.Series(self.descriptor.item())

    def triad_descriptor(
        self, C1, C2, C3, U1, U2, U3, th, *args, **kwargs
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray[bool]]:
        """
        从陨石坑的参数中计算描述子，C1和C2可以是含有不确定度的椭圆参数
        TODO 必须支持向量化操作
        Output:
          这里的输出一定是编码后的Hash值！
        """
        raise NotImplementedError

    def hash_search(self, qs, ijk, index, vote_th=3):
        raise NotImplementedError

    def identify(self, params, uncertainty, *args, th=3, confidence=0.8, **kwargs):
        """
        输入应该越接近于真实的检测输出越好。因此输入应当是一堆椭圆的拟合参数，输出应当是匹配确定的结果，即椭圆ID的列表。
        Arguments:
            params (np.ndarray) : ellipse parameters for (x^2, xy, y^2, x, y, 1)
        Returns:
            list : matched ellipse ID
        """
        # 将陨石坑按直径排序
        params = np.array(params)
        diameters = np.mean(radius_ellipse(params), axis=0)
        idx = np.argsort(-diameters)
        if len(idx) < 4:
            return None
        # 似乎也可以不用构建全部二元组
        # 从直径最大、且二元组最多的陨石坑对开始构建金字塔
        i, j, k = np.array(list(combinations(idx, 3))).T
        # 批量算出金字塔的三元组
        qs, index, valid = self.triad_descriptor(
            params[i],
            params[j],
            params[k],
            uncertainty[i],
            uncertainty[j],
            uncertainty[k],
            th=2,
        )
        ijk = np.array((i, j, k))[:, valid]
        ijk = np.take_along_axis(ijk, index[:, valid], axis=0)
        # 变成四个三元组，即金字塔的结构
        result = self.hash_search(qs, ijk)
        if result is not None:
            success, result = self.extend_identify(
                params, uncertainty, *result, th=th, confidence=confidence, **kwargs
            )
            if success:
                return result
        return None

    def hash_search(self, qs, ijks):
        votes = np.zeros((self.catalog.shape[0], self.catalog.shape[0]), dtype=np.int32)
        ijks = ijks.T
        # 取消投票法，改为使用金字塔法
        v = self.descriptor.reindex(qs.flatten())
        v = pd.DataFrame(v.values.reshape(qs.shape[0], -1))
        # 合并行
        v = v.apply(lambda x: x.dropna().tolist(), axis=0)
        v_num = v.apply(len)
        ind = v_num > 0
        if not ind.any():
            return None
        # 为防止某些表位置处含有两个以上的值，即v_num测出来的值比实际小，这里需要连接后再次计算
        v = v[ind].apply(np.concatenate)
        v_num = v.apply(len)
        ## 处理其中冗余项，即对原始坐标加重复
        IJK = np.concatenate(v.tolist()).flatten()
        ijk = ijks[ind].repeat(v_num[ind], axis=0).flatten()
        ## 投票
        ## 投票时保证同一坐标不会重复投票
        vals, counts = np.unique((ijk, IJK), axis=1, return_counts=True)
        votes[vals[0], vals[1]] += counts
        ## 用匈牙利算法完成识别分配
        ijkl, IJKL = linear_sum_assignment(-votes)
        # 最优分配
        ## 取得票数最高的前四个
        candidate = votes[ijkl, IJKL]
        ind = np.argsort(-candidate)[:8]
        ind = ind[candidate[ind] > 0]
        if ind.shape[0] < 4:
            return None
        else:
            return ijkl[ind], IJKL[ind]

    def extend_identify(
        self,
        params,
        uncertainty,
        ijkl,
        IJKL,
        K,
        th: float,
        confidence: float,
        dist_th=5,
        **kwargs,
    ):
        """
        TODO 用增量式重投影方法完成对全幅陨石坑的匹配，具体的细节如下：
        1. 利用现有四点ijkl估计出一个摄像机矩阵P（TODO 有可能估计失败）
        2. 根据当前的四点ijkl，设置某个超参数r，搜索四点周围r个点构成候选重投影点，利用P重投影至图像平面，记为Qr
        3. 利用匈牙利算法，对当前的Qr+4个重投影点与当前图像上的全部陨石坑中心点（可能多于Qr+4）作匹配（TODO 算法输出差过大的点应当舍去）。
        4. 对于每个匹配上的陨石坑中心点，计算其不变量以排除潜在的噪声点。
        4. 匹配的结果作为新的ijkl，重新估计P，重复步骤2-3，直到全部匹配上或者达到最大迭代次数
        """
        # 返回第一符合要求的值
        # 将金字塔识别结果拓展至整个图像
        if len(ijkl) < 4:
            return False, None
        cnts = np.array(center_ellipse(params)).T
        ijkl, IJKL = self.increment_hungary(params, ijkl, IJKL, K, dist_th)
        if len(ijkl) < 4:
            return False, None
            # 重投影验证
        points_3d = self.catalog[IJKL, :]
        obse_2d = cnts[ijkl, :]
        R, T = pose_calculate(K, points_3d, obse_2d)
        if R is None:
            return False, None
        uv = cv2.projectPoints(points_3d, R, T, K, np.zeros(4))[0].squeeze()
        ind = np.linalg.norm(uv - obse_2d, axis=1) < th
        if ind.mean() >= confidence:
            return True, (ijkl[ind], IJKL[ind])
        else:
            return False, None

    def increment_hungary(
        self,
        params,
        uncertainty,
        ijkl: tuple,
        IJKL: tuple,
        K,
        dist_th=2.0,
    ):
        """
        用增量式重投影方法完成对全幅陨石坑的匹配，具体的细节如下：
        1. 利用现有四点ijkl估计出一个摄像机矩阵P（TODO 有可能估计失败）
        2. 根据当前的四点ijkl，设置某个超参数r，搜索四点周围r个点构成候选重投影点，利用P重投影至图像平面，记为Qr
        3. 利用匈牙利算法，对当前的Qr+4个重投影点与当前图像上的全部陨石坑中心点（可能多于Qr+4）作匹配；
        4. 对于每个匹配上的陨石坑中心点，计算其不变量以排除潜在的噪声点。
        4. 匹配的结果作为新的ijkl，重新估计P，重复步骤2-3，直到全部匹配上或者达到最大迭代次数
        Arguments:
            ijkl (tuple) : 当前图像每个被匹配的陨石坑序号
            IJKL (tuple) : 当前图像每个被匹配的陨石坑在目录中的序号
            K (np.ndarray) : 相机内参矩阵
            dist_th (float) : 匹配距离阈值
        """
        # 估计P
        points_3d = self.catalog[IJKL, :]
        cnts = np.array(center_ellipse(params)).T
        obse_2d = cnts[ijkl, :]
        R, T = pose_calculate(K, points_3d, obse_2d)
        if R is None:
            return [], []
        # 重投影
        uv = cv2.projectPoints(self.catalog, R, T, K, np.zeros(4))[0].squeeze()
        # 把当前可能位于图像平面内的全部点都加入至候选列表中
        ind = (
            (uv[:, 0] < K[0, 2] * 2)
            & (uv[:, 0] >= 0)
            & (uv[:, 1] < K[1, 2] * 2)
            & (uv[:, 1] >= 0)
        )
        # 得到全部候选中心点列表
        cata_3d = uv[ind]
        cata_id = np.arange(self.catalog.shape[0])[ind]
        # 获取当前全部图像中心点
        obse_id = np.arange(len(cnts))
        # 匈牙利算法
        dist_matrix = np.linalg.norm(cata_3d[:, np.newaxis] - cnts, axis=2)
        # 使用匈牙利算法计算最优分配
        row_ind, col_ind = linear_sum_assignment(dist_matrix)
        # 最优分配
        cata_id = cata_id[row_ind]
        obse_id = obse_id[col_ind]
        # 计算匹配误差，本误差也将作为筛选真实陨石坑和虚假陨石坑的依据
        error = dist_matrix[row_ind, col_ind]
        # 计算不变量验证
        ## 具体的步骤应当包括：
        ## 1. 根据匈牙利算法匹配出的距离远近和设计的超参数dist_th，选出符合匹配距离范围内的候选陨石坑
        ## 2. 全部候选陨石坑两两作对，算出不变量以及配对
        ## 3. 根据配对关系，给出导航坑表内的不变量对
        # 挑选匹配距离范围内的陨石坑
        ind = error < dist_th
        cata_id = cata_id[ind]
        obse_id = obse_id[ind]
        return obse_id, cata_id
